{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the working directory; the relative `data/...` paths used in later cells resolve against it.\n",
    "!pwd\n",
    "# One-time environment setup, kept commented out. Uncomment to install the\n",
    "# Python package in editable mode and the npm dependencies.\n",
    "# !pip install -e .\n",
    "# !npm install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "# NOTE(review): machine-specific absolute path; point it at the local checkout.\n",
    "REPO_SRC = '/work/paras/representjs/representjs'\n",
    "# Guard the append so re-running this cell does not stack duplicate entries.\n",
    "if REPO_SRC not in sys.path:\n",
    "    sys.path.append(REPO_SRC)\n",
    "\n",
    "import torch\n",
    "import sentencepiece as spm\n",
    "from models.code_mlm import CodeMLM  # resolved through REPO_SRC above\n",
    "\n",
    "SPM_MODEL_PATH = \"data/codesearchnet_javascript/csnjs_8k_9995p_unigram_url.model\"\n",
    "CKPT_PATH = \"data/runs/22006_roberta_no_weight_decay_1590101821/ckpt_pretrain_ep0003_step0050000.pth\"\n",
    "\n",
    "# Tokenizer shared by every later cell.\n",
    "sp = spm.SentencePieceProcessor()\n",
    "sp.Load(SPM_MODEL_PATH)\n",
    "\n",
    "pad_id = sp.PieceToId(\"[PAD]\")\n",
    "mask_id = sp.PieceToId(\"[MASK]\")\n",
    "# PieceToId falls back to the <unk> id for pieces that are not in the vocabulary,\n",
    "# so fail loudly instead of silently padding/masking with <unk>.\n",
    "for piece, piece_id in ((\"[PAD]\", pad_id), (\"[MASK]\", mask_id)):\n",
    "    if piece_id == sp.unk_id():\n",
    "        raise ValueError(f\"{piece} is not in the SentencePiece vocabulary\")\n",
    "\n",
    "# Absolute checkpoint path; checked here so a wrong path fails before the model is built.\n",
    "ckpt_file = Path(CKPT_PATH).resolve()\n",
    "if not ckpt_file.is_file():\n",
    "    raise FileNotFoundError(f\"checkpoint not found: {ckpt_file}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when one is visible, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Instantiate CodeMLM with the tokenizer's vocabulary size and padding id.\n",
    "model = CodeMLM(sp.GetPieceSize(), pad_id=pad_id).to(device)\n",
    "\n",
    "# map_location remaps the stored tensors onto `device`, so a checkpoint saved on a\n",
    "# GPU that is absent here still loads.\n",
    "# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.\n",
    "checkpoint = torch.load(ckpt_file, map_location=device)\n",
    "model.load_state_dict(checkpoint['model_state_dict'])\n",
    "\n",
    "# Switch to inference mode; as the last expression this also displays the model.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# JavaScript snippet to probe the model with.\n",
    "string = \"\"\"const x = function (z) {\n",
    "    var x = 1;\n",
    "    for (var i = 0; i < 10; i++) {\n",
    "        x += i;\n",
    "    }\n",
    "    return z + x;\n",
    "}\"\"\"\n",
    "\n",
    "mask_pos = 14  # index into `seq` (position 0 is <s>) of the token to hide\n",
    "top_k = 10     # number of candidate pieces to print\n",
    "\n",
    "# Wrap the encoded snippet in the <s> ... </s> markers and mask one position in a copy.\n",
    "seq = [sp.PieceToId(\"<s>\")] + sp.EncodeAsIds(string) + [sp.PieceToId(\"</s>\")]\n",
    "masked_seq = seq[:]\n",
    "masked_seq[mask_pos] = mask_id\n",
    "\n",
    "print(sp.DecodeIds(seq))\n",
    "print(sp.DecodeIds(masked_seq))\n",
    "\n",
    "# Put the input on whatever device the model's weights live on instead of assuming CUDA.\n",
    "input_device = next(model.parameters()).device\n",
    "input_ids = torch.tensor(masked_seq, dtype=torch.long, device=input_device).unsqueeze(0)\n",
    "with torch.no_grad():\n",
    "    logits = model(input_ids)\n",
    "print(\"\\n\")\n",
    "\n",
    "# Scores at the masked position -- assumes the model returns (batch, seq, vocab); TODO confirm.\n",
    "mask_logits = logits[0, mask_pos]\n",
    "probs = mask_logits.softmax(dim=0)\n",
    "# Rank by raw logits; only the indices are needed because probabilities are read from `probs`.\n",
    "_, topk_idx = mask_logits.topk(top_k, largest=True)\n",
    "for idx in topk_idx.tolist():\n",
    "    word_piece = sp.IdToPiece(idx)\n",
    "    print(f\"{word_piece}\\t{probs[idx].item():.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
